{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-alpha0\n",
      "sys.version_info(major=3, minor=5, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.0.3\n",
      "numpy 1.16.2\n",
      "pandas 0.24.2\n",
      "sklearn 0.21.2\n",
      "tensorflow 2.0.0-alpha0\n",
      "tensorflow.python.keras.api._v2.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "# Record the interpreter and library versions this notebook was run with.\n",
    "print(tf.__version__)\n",
    "print(sys.version_info)\n",
    "for lib in (mpl, np, pd, sklearn, tf, keras):\n",
    "    print(lib.__name__, lib.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1115394\n",
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You\n"
     ]
    }
   ],
   "source": [
    "# Corpus source:\n",
    "# https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt\n",
    "input_filepath = \"./shakespeare.txt\"\n",
    "# A context manager closes the file handle deterministically, and pinning the\n",
    "# encoding keeps the read independent of the platform's default encoding.\n",
    "with open(input_filepath, 'r', encoding='utf-8') as corpus_file:\n",
    "    text = corpus_file.read()\n",
    "\n",
    "print(len(text))\n",
    "print(text[0:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65\n",
      "['\\n', ' ', '!', '$', '&', \"'\", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing plan:\n",
    "#   1. collect the vocabulary (every distinct character in the corpus)\n",
    "#   2. build the char -> id mapping\n",
    "#   3. encode the corpus as ids\n",
    "#   4. derive (input, target) pairs offset by one step: abcde -> abcd, bcde\n",
    "\n",
    "vocab = list(set(text))\n",
    "vocab.sort()\n",
    "print(len(vocab))\n",
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'x': 62, 'Z': 38, 'F': 18, 'T': 32, 'n': 52, 'A': 13, '3': 9, 'b': 40, 'C': 15, 'e': 43, 'q': 55, 'K': 23, 'z': 64, 'E': 17, \"'\": 5, 'Y': 37, 'd': 42, 'W': 35, 'S': 31, 'p': 54, 'I': 21, 'J': 22, 'y': 63, 'X': 36, 'G': 19, '$': 3, 'N': 26, 'w': 61, 'i': 47, ' ': 1, 'h': 46, 'j': 48, 'o': 53, ',': 6, '-': 7, 'H': 20, 's': 57, 'U': 33, '!': 2, 'g': 45, 'v': 60, ':': 10, 'r': 56, 'u': 59, 'V': 34, 'O': 27, 'P': 28, ';': 11, 'L': 24, '.': 8, 'B': 14, '\\n': 0, 'k': 49, 'Q': 29, 'c': 41, 't': 58, 'M': 25, 'R': 30, 'f': 44, 'm': 51, 'l': 50, '&': 4, '?': 12, 'a': 39, 'D': 16}\n"
     ]
    }
   ],
   "source": [
    "# Each character's id is its position in the sorted vocabulary.\n",
    "char2idx = dict(zip(vocab, range(len(vocab))))\n",
    "print(char2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['\\n' ' ' '!' '$' '&' \"'\" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'\n",
      " 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'\n",
      " 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'\n",
      " 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']\n"
     ]
    }
   ],
   "source": [
    "# Inverse lookup kept as a NumPy array so that an array of ids can be decoded\n",
    "# with a single indexing operation.\n",
    "idx2char = np.asarray(vocab)\n",
    "print(idx2char)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[18 47 56 57 58  1 15 47 58 47]\n",
      "First Citi\n"
     ]
    }
   ],
   "source": [
    "# Encode the whole corpus as an array of character ids.\n",
    "text_as_int = np.array(list(map(char2idx.__getitem__, text)))\n",
    "print(text_as_int[:10])\n",
    "print(text[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(18, shape=(), dtype=int64) F\n",
      "tf.Tensor(47, shape=(), dtype=int64) i\n",
      "tf.Tensor(\n",
      "[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43\n",
      "  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43\n",
      " 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49\n",
      "  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10\n",
      "  0 37 53 59  1], shape=(101,), dtype=int64)\n",
      "'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n",
      "tf.Tensor(\n",
      "[39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1\n",
      " 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0\n",
      " 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8\n",
      "  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1\n",
      " 63 53 59  1 49], shape=(101,), dtype=int64)\n",
      "'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n"
     ]
    }
   ],
   "source": [
    "def split_input_target(id_text):\n",
    "    \"\"\"Turn one sequence into an (input, target) pair offset by one step.\n",
    "\n",
    "    The input drops the last element and the target drops the first, so\n",
    "    target[i] is the element that follows input[i]: abcde -> (abcd, bcde).\n",
    "    \"\"\"\n",
    "    model_input = id_text[:-1]\n",
    "    next_step_target = id_text[1:]\n",
    "    return model_input, next_step_target\n",
    "\n",
    "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n",
    "seq_length = 100\n",
    "# Chunks hold seq_length + 1 ids so that, once split, input and target are\n",
    "# each seq_length long; drop_remainder discards the incomplete final chunk.\n",
    "seq_dataset = char_dataset.batch(seq_length + 1, drop_remainder=True)\n",
    "\n",
    "# Peek at the first two characters ...\n",
    "for char_id in char_dataset.take(2):\n",
    "    print(char_id, idx2char[char_id.numpy()])\n",
    "\n",
    "# ... and at the first two chunks, decoded back to text.\n",
    "for chunk in seq_dataset.take(2):\n",
    "    print(chunk)\n",
    "    decoded_chunk = ''.join(idx2char[chunk.numpy()])\n",
    "    print(repr(decoded_chunk))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43\n",
      "  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43\n",
      " 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49\n",
      "  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10\n",
      "  0 37 53 59]\n",
      "[47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43  1\n",
      " 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43 39\n",
      " 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49  6\n",
      "  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0\n",
      " 37 53 59  1]\n",
      "[39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1\n",
      " 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0\n",
      " 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8\n",
      "  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1\n",
      " 63 53 59  1]\n",
      "[56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1 58\n",
      " 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0 13\n",
      " 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8  0\n",
      "  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1 63\n",
      " 53 59  1 49]\n"
     ]
    }
   ],
   "source": [
    "# NOTE: this rebinds seq_dataset, so running the cell a second time would map\n",
    "# the already-split pairs again; re-create seq_dataset before re-running.\n",
    "seq_dataset = seq_dataset.map(split_input_target)\n",
    "\n",
    "# Show the first two (input, target) pairs; each target is its input shifted\n",
    "# left by one position.\n",
    "for input_ids, target_ids in seq_dataset.take(2):\n",
    "    print(input_ids.numpy())\n",
    "    print(target_ids.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "buffer_size = 10000\n",
    "\n",
    "# Shuffle the (input, target) pairs, then group them into fixed-size batches.\n",
    "# drop_remainder keeps every batch at exactly batch_size elements.\n",
    "seq_dataset = seq_dataset.shuffle(buffer_size)\n",
    "seq_dataset = seq_dataset.batch(batch_size, drop_remainder=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (64, None, 256)           16640     \n",
      "_________________________________________________________________\n",
      "simple_rnn (SimpleRNN)       (64, None, 1024)          1311744   \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (64, None, 65)            66625     \n",
      "=================================================================\n",
      "Total params: 1,395,009\n",
      "Trainable params: 1,395,009\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "vocab_size = len(vocab)\n",
    "embedding_dim = 256\n",
    "rnn_units = 1024\n",
    "\n",
    "def build_model(vocab_size, embedding_dim, rnn_units, batch_size):\n",
    "    \"\"\"Assemble the character-level model: Embedding -> SimpleRNN -> Dense.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of distinct character ids; used as the embedding\n",
    "            input size and as the width of the output layer.\n",
    "        embedding_dim: size of each character embedding vector.\n",
    "        rnn_units: number of units in the recurrent layer.\n",
    "        batch_size: fixed batch dimension of the model input.\n",
    "\n",
    "    Returns:\n",
    "        An uncompiled keras Sequential model whose Dense output has\n",
    "        vocab_size values per time step (no activation is applied).\n",
    "    \"\"\"\n",
    "    # Layers are created in forward order so their auto-generated names match\n",
    "    # the order shown by model.summary().\n",
    "    embedding = keras.layers.Embedding(\n",
    "        vocab_size, embedding_dim,\n",
    "        batch_input_shape=[batch_size, None])\n",
    "    recurrent = keras.layers.SimpleRNN(\n",
    "        units=rnn_units,\n",
    "        stateful=True,\n",
    "        recurrent_initializer='glorot_uniform',\n",
    "        return_sequences=True)\n",
    "    char_logits = keras.layers.Dense(vocab_size)\n",
    "    return keras.models.Sequential([embedding, recurrent, char_logits])\n",
    "\n",
    "model = build_model(\n",
    "    vocab_size=vocab_size,\n",
    "    embedding_dim=embedding_dim,\n",
    "    rnn_units=rnn_units,\n",
    "    batch_size=batch_size)\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 100, 65)\n"
     ]
    }
   ],
   "source": [
    "# Run one batch through the untrained model as a shape sanity check; the\n",
    "# result should be (batch_size, seq_length, vocab_size).\n",
    "single_batch = seq_dataset.take(1)\n",
    "for input_example_batch, target_example_batch in single_batch:\n",
    "    example_batch_predictions = model(input_example_batch)\n",
    "    print(example_batch_predictions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[18]\n",
      " [22]\n",
      " [22]\n",
      " [41]\n",
      " [24]\n",
      " [48]\n",
      " [42]\n",
      " [44]\n",
      " [37]\n",
      " [45]\n",
      " [ 6]\n",
      " [48]\n",
      " [ 6]\n",
      " [26]\n",
      " [36]\n",
      " [61]\n",
      " [ 8]\n",
      " [22]\n",
      " [63]\n",
      " [41]\n",
      " [20]\n",
      " [14]\n",
      " [20]\n",
      " [35]\n",
      " [57]\n",
      " [ 4]\n",
      " [61]\n",
      " [10]\n",
      " [10]\n",
      " [39]\n",
      " [34]\n",
      " [26]\n",
      " [43]\n",
      " [52]\n",
      " [29]\n",
      " [29]\n",
      " [43]\n",
      " [48]\n",
      " [38]\n",
      " [11]\n",
      " [13]\n",
      " [ 1]\n",
      " [64]\n",
      " [64]\n",
      " [45]\n",
      " [52]\n",
      " [45]\n",
      " [21]\n",
      " [ 5]\n",
      " [15]\n",
      " [52]\n",
      " [35]\n",
      " [16]\n",
      " [34]\n",
      " [ 4]\n",
      " [24]\n",
      " [48]\n",
      " [ 4]\n",
      " [27]\n",
      " [19]\n",
      " [39]\n",
      " [27]\n",
      " [30]\n",
      " [43]\n",
      " [16]\n",
      " [60]\n",
      " [42]\n",
      " [29]\n",
      " [42]\n",
      " [43]\n",
      " [ 1]\n",
      " [60]\n",
      " [ 9]\n",
      " [11]\n",
      " [37]\n",
      " [37]\n",
      " [36]\n",
      " [51]\n",
      " [16]\n",
      " [28]\n",
      " [60]\n",
      " [47]\n",
      " [ 8]\n",
      " [57]\n",
      " [27]\n",
      " [14]\n",
      " [16]\n",
      " [17]\n",
      " [47]\n",
      " [56]\n",
      " [18]\n",
      " [21]\n",
      " [52]\n",
      " [34]\n",
      " [22]\n",
      " [19]\n",
      " [13]\n",
      " [60]\n",
      " [21]\n",
      " [47]], shape=(100, 1), dtype=int64)\n",
      "tf.Tensor(\n",
      "[18 22 22 41 24 48 42 44 37 45  6 48  6 26 36 61  8 22 63 41 20 14 20 35\n",
      " 57  4 61 10 10 39 34 26 43 52 29 29 43 48 38 11 13  1 64 64 45 52 45 21\n",
      "  5 15 52 35 16 34  4 24 48  4 27 19 39 27 30 43 16 60 42 29 42 43  1 60\n",
      "  9 11 37 37 36 51 16 28 60 47  8 57 27 14 16 17 47 56 18 21 52 34 22 19\n",
      " 13 60 21 47], shape=(100,), dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "# Two ways to pick a character from the logits: greedy (argmax) or random\n",
    "# sampling. Random sampling is used here.\n",
    "first_example_logits = example_batch_predictions[0]\n",
    "# categorical draws one id per time step: (100, 65) -> (100, 1)\n",
    "sample_indices = tf.random.categorical(\n",
    "    logits=first_example_logits, num_samples=1)\n",
    "print(sample_indices)\n",
    "# drop the trailing sample axis: (100, 1) -> (100,)\n",
    "sample_indices = tf.squeeze(sample_indices, axis=-1)\n",
    "print(sample_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input:  \"'ertake his bad intent,\\nAnd must be buried but as an intent\\nThat perish'd by the way: thoughts are n\"\n",
      "\n",
      "Output:  \"ertake his bad intent,\\nAnd must be buried but as an intent\\nThat perish'd by the way: thoughts are no\"\n",
      "\n",
      "Predictions:  \"FJJcLjdfYg,j,NXw.JycHBHWs&w::aVNenQQejZ;A zzgngI'CnWDV&Lj&OGaOReDvdQde v3;YYXmDPvi.sOBDEirFInVJGAvIi\"\n"
     ]
    }
   ],
   "source": [
    "# Decode the first example of the batch, its target, and the sampled ids.\n",
    "input_chars = \"\".join(idx2char[input_example_batch[0]])\n",
    "target_chars = \"\".join(idx2char[target_example_batch[0]])\n",
    "sampled_chars = \"\".join(idx2char[sample_indices])\n",
    "\n",
    "print(\"Input: \", repr(input_chars))\n",
    "print()\n",
    "print(\"Output: \", repr(target_chars))\n",
    "print()\n",
    "print(\"Predictions: \", repr(sampled_chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 100)\n",
      "4.1913624\n"
     ]
    }
   ],
   "source": [
    "def loss(labels, logits):\n",
    "    \"\"\"Sparse categorical cross-entropy computed from raw logits.\n",
    "\n",
    "    Args:\n",
    "        labels: integer target character ids.\n",
    "        logits: unnormalized scores over the vocabulary for each position.\n",
    "\n",
    "    Returns:\n",
    "        The unreduced loss values, one per label.\n",
    "    \"\"\"\n",
    "    crossentropy = keras.losses.sparse_categorical_crossentropy\n",
    "    return crossentropy(labels, logits, from_logits=True)\n",
    "\n",
    "model.compile(optimizer='adam', loss=loss)\n",
    "\n",
    "# Sanity check on the untrained model: with 65 classes a uniform guess scores\n",
    "# ln(65), about 4.17, so the mean loss should start near that value.\n",
    "example_loss = loss(target_example_batch, example_batch_predictions)\n",
    "print(example_loss.shape)\n",
    "print(example_loss.numpy().mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "172/172 [==============================] - 44s 258ms/step - loss: 2.6715\n",
      "Epoch 2/100\n",
      "172/172 [==============================] - 42s 247ms/step - loss: 2.0199\n",
      "Epoch 3/100\n",
      "172/172 [==============================] - 44s 253ms/step - loss: 1.8145\n",
      "Epoch 4/100\n",
      "172/172 [==============================] - 43s 249ms/step - loss: 1.6819\n",
      "Epoch 5/100\n",
      "172/172 [==============================] - 43s 252ms/step - loss: 1.5952\n",
      "Epoch 6/100\n",
      "172/172 [==============================] - 42s 243ms/step - loss: 1.5344\n",
      "Epoch 7/100\n",
      "172/172 [==============================] - 42s 242ms/step - loss: 1.4893\n",
      "Epoch 8/100\n",
      "172/172 [==============================] - 42s 244ms/step - loss: 1.4533\n",
      "Epoch 9/100\n",
      "172/172 [==============================] - 42s 242ms/step - loss: 1.4256\n",
      "Epoch 10/100\n",
      "172/172 [==============================] - 44s 255ms/step - loss: 1.3982\n",
      "Epoch 11/100\n",
      "172/172 [==============================] - 44s 255ms/step - loss: 1.3769\n",
      "Epoch 12/100\n",
      "172/172 [==============================] - 43s 250ms/step - loss: 1.3606\n",
      "Epoch 13/100\n",
      "172/172 [==============================] - 42s 246ms/step - loss: 1.3445\n",
      "Epoch 14/100\n",
      "172/172 [==============================] - 46s 266ms/step - loss: 1.3285\n",
      "Epoch 15/100\n",
      "172/172 [==============================] - 46s 265ms/step - loss: 1.3134\n",
      "Epoch 16/100\n",
      "172/172 [==============================] - 43s 249ms/step - loss: 1.2971\n",
      "Epoch 17/100\n",
      "172/172 [==============================] - 42s 245ms/step - loss: 1.2858\n",
      "Epoch 18/100\n",
      "172/172 [==============================] - 43s 252ms/step - loss: 1.2795\n",
      "Epoch 19/100\n",
      "172/172 [==============================] - 43s 248ms/step - loss: 1.2717\n",
      "Epoch 20/100\n",
      "172/172 [==============================] - 43s 249ms/step - loss: 1.2661\n",
      "Epoch 21/100\n",
      "172/172 [==============================] - 42s 247ms/step - loss: 1.2664\n",
      "Epoch 22/100\n",
      "172/172 [==============================] - 42s 244ms/step - loss: 1.2622\n",
      "Epoch 23/100\n",
      "172/172 [==============================] - 44s 257ms/step - loss: 1.2488\n",
      "Epoch 24/100\n",
      "172/172 [==============================] - 44s 255ms/step - loss: 1.2376\n",
      "Epoch 25/100\n",
      "172/172 [==============================] - 46s 267ms/step - loss: 1.2314\n",
      "Epoch 26/100\n",
      "172/172 [==============================] - 44s 257ms/step - loss: 1.2292\n",
      "Epoch 27/100\n",
      "172/172 [==============================] - 47s 275ms/step - loss: 1.2240\n",
      "Epoch 28/100\n",
      "172/172 [==============================] - 47s 271ms/step - loss: 1.2187\n",
      "Epoch 29/100\n",
      "172/172 [==============================] - 45s 260ms/step - loss: 1.2170\n",
      "Epoch 30/100\n",
      "172/172 [==============================] - 43s 250ms/step - loss: 1.2153\n",
      "Epoch 31/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.2078\n",
      "Epoch 32/100\n",
      "172/172 [==============================] - 43s 252ms/step - loss: 1.1992\n",
      "Epoch 33/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.1921\n",
      "Epoch 34/100\n",
      "172/172 [==============================] - 46s 265ms/step - loss: 1.1871\n",
      "Epoch 35/100\n",
      "172/172 [==============================] - 47s 273ms/step - loss: 1.1852\n",
      "Epoch 36/100\n",
      "172/172 [==============================] - 46s 270ms/step - loss: 1.1853\n",
      "Epoch 37/100\n",
      "172/172 [==============================] - 45s 260ms/step - loss: 1.1817\n",
      "Epoch 38/100\n",
      "172/172 [==============================] - 45s 259ms/step - loss: 1.1756\n",
      "Epoch 39/100\n",
      "172/172 [==============================] - 45s 262ms/step - loss: 1.1720\n",
      "Epoch 40/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.1726\n",
      "Epoch 41/100\n",
      "172/172 [==============================] - 45s 264ms/step - loss: 1.1755\n",
      "Epoch 42/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.1742\n",
      "Epoch 43/100\n",
      "172/172 [==============================] - 45s 259ms/step - loss: 1.1753\n",
      "Epoch 44/100\n",
      "172/172 [==============================] - 44s 255ms/step - loss: 1.1738\n",
      "Epoch 45/100\n",
      "172/172 [==============================] - 45s 262ms/step - loss: 1.1729\n",
      "Epoch 46/100\n",
      "172/172 [==============================] - 46s 266ms/step - loss: 1.1669\n",
      "Epoch 47/100\n",
      "172/172 [==============================] - 46s 270ms/step - loss: 1.1588\n",
      "Epoch 48/100\n",
      "172/172 [==============================] - 45s 264ms/step - loss: 1.1575\n",
      "Epoch 49/100\n",
      "172/172 [==============================] - 45s 260ms/step - loss: 1.1522\n",
      "Epoch 50/100\n",
      "172/172 [==============================] - 46s 266ms/step - loss: 1.1487\n",
      "Epoch 51/100\n",
      "172/172 [==============================] - 46s 269ms/step - loss: 1.1487\n",
      "Epoch 52/100\n",
      "172/172 [==============================] - 45s 263ms/step - loss: 1.1483\n",
      "Epoch 53/100\n",
      "172/172 [==============================] - 44s 255ms/step - loss: 1.1493\n",
      "Epoch 54/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.1504\n",
      "Epoch 55/100\n",
      "172/172 [==============================] - 44s 253ms/step - loss: 1.1515\n",
      "Epoch 56/100\n",
      "172/172 [==============================] - 45s 260ms/step - loss: 1.1475\n",
      "Epoch 57/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.1474\n",
      "Epoch 58/100\n",
      "172/172 [==============================] - 45s 262ms/step - loss: 1.1446\n",
      "Epoch 59/100\n",
      "172/172 [==============================] - 45s 262ms/step - loss: 1.1424\n",
      "Epoch 60/100\n",
      "172/172 [==============================] - 46s 269ms/step - loss: 1.1441\n",
      "Epoch 61/100\n",
      "172/172 [==============================] - 46s 267ms/step - loss: 1.1412\n",
      "Epoch 62/100\n",
      "172/172 [==============================] - 46s 270ms/step - loss: 1.1413\n",
      "Epoch 63/100\n",
      "172/172 [==============================] - 46s 265ms/step - loss: 1.1443\n",
      "Epoch 64/100\n",
      "172/172 [==============================] - 46s 266ms/step - loss: 1.1448\n",
      "Epoch 65/100\n",
      "172/172 [==============================] - 46s 265ms/step - loss: 1.1450\n",
      "Epoch 66/100\n",
      "172/172 [==============================] - 44s 258ms/step - loss: 1.1474\n",
      "Epoch 67/100\n",
      "172/172 [==============================] - 45s 264ms/step - loss: 1.1474\n",
      "Epoch 68/100\n",
      "172/172 [==============================] - 46s 267ms/step - loss: 1.1486\n",
      "Epoch 69/100\n",
      "172/172 [==============================] - 46s 268ms/step - loss: 1.1476\n",
      "Epoch 70/100\n",
      "172/172 [==============================] - 45s 264ms/step - loss: 1.1495\n",
      "Epoch 71/100\n",
      "172/172 [==============================] - 44s 258ms/step - loss: 1.1506\n",
      "Epoch 72/100\n",
      "172/172 [==============================] - 43s 248ms/step - loss: 1.1544\n",
      "Epoch 73/100\n",
      "172/172 [==============================] - 46s 268ms/step - loss: 1.1541\n",
      "Epoch 74/100\n",
      "172/172 [==============================] - 45s 262ms/step - loss: 1.1587\n",
      "Epoch 75/100\n",
      "172/172 [==============================] - 45s 260ms/step - loss: 1.1578\n",
      "Epoch 76/100\n",
      "172/172 [==============================] - 46s 268ms/step - loss: 1.1573\n",
      "Epoch 77/100\n",
      "172/172 [==============================] - 45s 263ms/step - loss: 1.1596\n",
      "Epoch 78/100\n",
      "172/172 [==============================] - 47s 270ms/step - loss: 1.1600\n",
      "Epoch 79/100\n",
      "172/172 [==============================] - 45s 264ms/step - loss: 1.1677\n",
      "Epoch 80/100\n",
      "172/172 [==============================] - 45s 263ms/step - loss: 1.1643\n",
      "Epoch 81/100\n",
      "172/172 [==============================] - 46s 266ms/step - loss: 1.1679\n",
      "Epoch 82/100\n",
      "172/172 [==============================] - 46s 267ms/step - loss: 1.1707\n",
      "Epoch 83/100\n",
      "172/172 [==============================] - 45s 260ms/step - loss: 1.1684\n",
      "Epoch 84/100\n",
      "172/172 [==============================] - 45s 264ms/step - loss: 1.1718\n",
      "Epoch 85/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.1779\n",
      "Epoch 86/100\n",
      "172/172 [==============================] - 44s 257ms/step - loss: 1.1780\n",
      "Epoch 87/100\n",
      "172/172 [==============================] - 46s 265ms/step - loss: 1.1796\n",
      "Epoch 88/100\n",
      "172/172 [==============================] - 46s 270ms/step - loss: 1.1776\n",
      "Epoch 89/100\n",
      "172/172 [==============================] - 46s 268ms/step - loss: 1.1848\n",
      "Epoch 90/100\n",
      "172/172 [==============================] - 47s 271ms/step - loss: 1.1847\n",
      "Epoch 91/100\n",
      "172/172 [==============================] - 46s 265ms/step - loss: 1.1857\n",
      "Epoch 92/100\n",
      "172/172 [==============================] - 46s 266ms/step - loss: 1.1885\n",
      "Epoch 93/100\n",
      "172/172 [==============================] - 45s 263ms/step - loss: 1.1917\n",
      "Epoch 94/100\n",
      "172/172 [==============================] - 44s 258ms/step - loss: 1.1908\n",
      "Epoch 95/100\n",
      "172/172 [==============================] - 43s 252ms/step - loss: 1.1998\n",
      "Epoch 96/100\n",
      "172/172 [==============================] - 46s 270ms/step - loss: 1.2035\n",
      "Epoch 97/100\n",
      "172/172 [==============================] - 46s 267ms/step - loss: 1.2057\n",
      "Epoch 98/100\n",
      "172/172 [==============================] - 44s 258ms/step - loss: 1.2016\n",
      "Epoch 99/100\n",
      "172/172 [==============================] - 44s 256ms/step - loss: 1.2069\n",
      "Epoch 100/100\n",
      "172/172 [==============================] - 46s 267ms/step - loss: 1.2162\n"
     ]
    }
   ],
   "source": [
    "output_dir = \"./text_generation_checkpoints\"\n",
    "# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: there is no\n",
    "# window between the check and the creation, and nested paths also work.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "# '{epoch}' is filled in by ModelCheckpoint, giving one weights file per epoch.\n",
    "checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epoch}')\n",
    "checkpoint_callback = keras.callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_prefix,\n",
    "    save_weights_only=True)\n",
    "\n",
    "# NOTE(review): in the recorded run the training loss bottoms out at epoch 61\n",
    "# (1.1412) and drifts back up to 1.2162 by epoch 100, so the last checkpoint\n",
    "# is not the lowest-loss one. Consider fewer epochs or a lower learning rate.\n",
    "epochs = 100\n",
    "history = model.fit(seq_dataset, epochs=epochs,\n",
    "                    callbacks=[checkpoint_callback])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'./text_generation_checkpoints/ckpt_100'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Path prefix of the most recently written checkpoint in output_dir.\n",
    "latest_checkpoint_path = tf.train.latest_checkpoint(output_dir)\n",
    "latest_checkpoint_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (1, None, 256)            16640     \n",
      "_________________________________________________________________\n",
      "simple_rnn_1 (SimpleRNN)     (1, None, 1024)           1311744   \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (1, None, 65)             66625     \n",
      "=================================================================\n",
      "Total params: 1,395,009\n",
      "Trainable params: 1,395,009\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the same architecture with batch size 1 for generation and restore\n",
    "# the trained weights from the latest checkpoint.\n",
    "model2 = build_model(vocab_size,\n",
    "                     embedding_dim,\n",
    "                     rnn_units,\n",
    "                     batch_size=1)\n",
    "trained_weights_path = tf.train.latest_checkpoint(output_dir)\n",
    "model2.load_weights(trained_weights_path)\n",
    "model2.build(tf.TensorShape([1, None]))\n",
    "\n",
    "# Conceptual generation scheme, starting from a seed character sequence A:\n",
    "#   A          -> model -> b,  then B = A + b\n",
    "#   B (= Ab)   -> model -> c,  then C = B + c\n",
    "#   C (= Abc)  -> model -> ...\n",
    "# The RNN is stateful, so after the seed only the newest character has to be\n",
    "# fed at each step; the earlier context lives in the recurrent state.\n",
    "model2.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All: his face beauty's slanders. Well, and some unto the storm'd fear; then dreams.\n",
      "\n",
      "CORIOLANUS:\n",
      "I will be made;\n",
      "Rike unwilling I will purposed-head:\n",
      "And I will stir the man\n",
      "And all the from your parciur flies by himself. Well;\n",
      "I am the widows fair? s, Auch\n",
      "An utterance on many lines, I'ld s, and thou never ed.\n",
      "Some happy stride, and supll you want fairest twenty soled no geedle moise his own alone! I am toUS:\n",
      "Hair undood oppasies? Or Your issues sir,\n",
      "That lies alenty; and therefore from your wealth.\n",
      "\n",
      "ROMEO:\n",
      "As I beleavent an England.\n",
      "\n",
      "TYBALT:\n",
      "My love: be of thee the flesh long-prounds. Grace now, gentle more than Jove, my gracious souls,\n",
      "Brief, his father,\n",
      "As, by some instery assurakery live,\n",
      "And finderape.\n",
      "\n",
      "CARESLENT:\n",
      "My lord of France?\n",
      "Why, what I had sees to Rome, name myself; it wass'd, being all eet and being sudd it yet?\n",
      "\n",
      "RAMILANUS:\n",
      "Now, brother Clarence.\n",
      "Of all run arms\n",
      "By Bianciau-house your grandsistail.\n",
      "\n",
      "PETRUCHIO:\n",
      "Come, content ines, it is here?\n",
      "This day, my lord, will you thing\n"
     ]
    }
   ],
   "source": [
    "def generate_text(model, start_string, num_generate=1000, temperature=1.0):\n",
    "    \"\"\"Sample text from a stateful character model, seeded with start_string.\n",
    "\n",
    "    Args:\n",
    "        model: stateful model built for batch size 1 that maps a batch of\n",
    "            character ids to per-step logits over the vocabulary.\n",
    "        start_string: non-empty seed text; every character must be a key of\n",
    "            char2idx (an unknown character raises KeyError).\n",
    "        num_generate: number of characters to sample after the seed.\n",
    "        temperature: positive divisor applied to the logits before sampling.\n",
    "            Values below 1 make the output more predictable, values above 1\n",
    "            more random; the default 1.0 samples from the unscaled logits.\n",
    "\n",
    "    Returns:\n",
    "        start_string followed by the num_generate sampled characters.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if start_string is empty or temperature is not positive.\n",
    "    \"\"\"\n",
    "    # Without a seed there is no last time step to sample from; fail early\n",
    "    # with a clear message instead of an obscure error inside the model.\n",
    "    if not start_string:\n",
    "        raise ValueError(\"start_string must contain at least one character\")\n",
    "    if temperature <= 0:\n",
    "        raise ValueError(\"temperature must be positive\")\n",
    "\n",
    "    input_eval = [char2idx[ch] for ch in start_string]\n",
    "    # add the batch axis: [seed_len] -> [1, seed_len]\n",
    "    input_eval = tf.expand_dims(input_eval, 0)\n",
    "\n",
    "    text_generated = []\n",
    "    # Start from a clean recurrent state so earlier calls do not leak in.\n",
    "    model.reset_states()\n",
    "\n",
    "    for _ in range(num_generate):\n",
    "        # predictions: [batch_size, input_eval_len, vocab_size]\n",
    "        predictions = model(input_eval)\n",
    "        # predictions: [input_eval_len, vocab_size]\n",
    "        predictions = tf.squeeze(predictions, 0)\n",
    "        predictions = predictions / temperature\n",
    "        # One id is sampled per time step ([input_eval_len, 1]); only the\n",
    "        # sample for the last step is the prediction of the next character.\n",
    "        predicted_id = tf.random.categorical(\n",
    "            predictions, num_samples=1)[-1, 0].numpy()\n",
    "        text_generated.append(idx2char[predicted_id])\n",
    "        # The model is stateful (s, x -> rnn -> s', y), so only the newly\n",
    "        # sampled character is fed back in on the next step.\n",
    "        input_eval = tf.expand_dims([predicted_id], 0)\n",
    "    return start_string + ''.join(text_generated)\n",
    "\n",
    "# Generate a sample from the restored model, seeded with a speaker label.\n",
    "new_text = generate_text(model=model2, start_string=\"All: \")\n",
    "print(new_text)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
